import os
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed, dtype
from src.dataset import create_dataset
from src.loss import SoftmaxCrossEntropyLoss
from src.network import Segmentation, BuildTrainNetwork
from src import learning_rates


def parse_args(argv=None):
    """Parse command-line options for NAIC 2020 remote-sensing training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is how the
            script's entry point calls this function.

    Returns:
        argparse.Namespace with the recognised options. Unrecognised arguments
        are ignored (``parse_known_args``), presumably so that extra flags
        injected by a job launcher do not abort training.
    """
    parser = argparse.ArgumentParser('mindspore NAIC 2020 Remote Sensing training')
    parser.add_argument('--train_dir', type=str, default='', help='where training log and ckpts saved')

    # dataset
    parser.add_argument('--data_url', type=str, default='', help='path of train data')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--crop_size', type=int, default=256, help='crop size')
    # Per-channel normalisation values, given space-separated on the command
    # line (e.g. ``--image_mean 0.5 0.5 0.5``). ``type=list`` would split the
    # raw string into single characters, so parse each value as a float.
    parser.add_argument('--image_mean', type=float, nargs='+', default=[0.5, 0.5, 0.5], help='image mean')
    parser.add_argument('--image_std', type=float, nargs='+', default=[0.5, 0.5, 0.5], help='image std')
    parser.add_argument('--min_scale', type=float, default=0.5, help='minimum scale of data augmentation')
    parser.add_argument('--max_scale', type=float, default=2.0, help='maximum scale of data augmentation')
    parser.add_argument('--ignore_label', type=int, default=255, help='ignore label')
    parser.add_argument('--num_classes', type=int, default=13, help='number of classes')

    # optimizer
    parser.add_argument('--train_epochs', type=int, default=100, help='epoch')
    parser.add_argument('--lr_type', type=str, default='cos', help='type of learning rate')
    parser.add_argument('--base_lr', type=float, default=0.015, help='base learning rate')
    parser.add_argument('--lr_decay_step', type=int, default=40000, help='learning rate decay step')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='learning rate decay rate')
    parser.add_argument('--loss_scale', type=float, default=3072.0, help='loss scale')

    # model
    parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model')

    # train
    parser.add_argument('--is_distributed', action='store_true', help='distributed training')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
    parser.add_argument('--save_steps', type=int, default=3000, help='steps interval for saving')
    parser.add_argument('--keep_checkpoint_max', type=int, default=5, help='max checkpoint for saving')

    # device
    parser.add_argument('--device', type=str, default='GPU', choices=['GPU', 'Ascend'])

    args, _ = parser.parse_known_args(argv)
    return args


def _init_distributed(args):
    """Initialise collective communication and data-parallel training.

    Overwrites ``args.rank`` and ``args.group_size`` with the values reported
    by the communication backend, so dataset sharding and checkpoint saving
    use the real rank / world size rather than the CLI defaults.
    """
    # With no backend name, MindSpore's init() picks the backend from the
    # configured device target (per the MindSpore communication docs).
    init()
    args.rank = get_rank()
    args.group_size = get_group_size()
    # gradients_mean=True averages gradients across devices instead of summing.
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True,
                                      device_num=args.group_size)


def _build_lr_schedule(args, total_train_steps):
    """Return the learning-rate schedule selected by ``args.lr_type``.

    Supported types are ``'cos'``, ``'poly'`` and ``'exp'``; the schedule
    itself comes from ``src.learning_rates`` and spans ``total_train_steps``.

    Raises:
        ValueError: if ``args.lr_type`` is not one of the supported types.
    """
    if args.lr_type == 'cos':
        return learning_rates.cosine_lr(args.base_lr, total_train_steps, total_train_steps)
    if args.lr_type == 'poly':
        return learning_rates.poly_lr(args.base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
    if args.lr_type == 'exp':
        return learning_rates.exponential_lr(args.base_lr, args.lr_decay_step, args.lr_decay_rate,
                                             total_train_steps, staircase=True)
    raise ValueError(f'unknown learning rate type: {args.lr_type!r}')


def _build_callbacks(args, iters_per_epoch):
    """Return timing/loss monitors, plus a checkpoint saver on rank 0 only."""
    callbacks = [TimeMonitor(data_size=iters_per_epoch), LossMonitor()]
    if args.rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args.save_steps,
                                     keep_checkpoint_max=args.keep_checkpoint_max)
        # An empty string is not a valid checkpoint directory for
        # ModelCheckpoint; None selects the current working directory, which
        # is what the default ``--train_dir ''`` intends.
        callbacks.append(ModelCheckpoint(prefix='naic', directory=args.train_dir or None, config=config_ck))
    return callbacks


def train(args):
    """Build the segmentation training pipeline from ``args`` and run it.

    Args:
        args: argparse.Namespace produced by ``parse_args``. When
            ``args.is_distributed`` is set, ``args.rank`` and
            ``args.group_size`` are overwritten with the backend's values.

    Raises:
        ValueError: if ``args.lr_type`` is not a supported schedule type.
    """
    if args.is_distributed:
        _init_distributed(args)

    # dataset, sharded by rank (presumably one shard per device)
    dataset = create_dataset(
        data_dir=args.data_url,
        image_mean=args.image_mean,
        image_std=args.image_std,
        batch_size=args.batch_size,
        crop_size=args.crop_size,
        max_scale=args.max_scale,
        min_scale=args.min_scale,
        ignore_label=args.ignore_label,
        num_readers=2,
        num_parallel_calls=4,
        shard_id=args.rank,
        shard_num=args.group_size,
        train=True,
        repeat=1,
    )

    # network + loss; the loss is flagged/cast to fp32, presumably so the
    # "O3" mixed-precision level used below does not run it in fp16.
    network = Segmentation(num_classes=args.num_classes)
    loss_ = SoftmaxCrossEntropyLoss(args.num_classes)
    loss_.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(network, loss_)
    train_net.criterion.to_float(dtype.float32)

    # optionally warm-start from a checkpoint
    if args.ckpt_pre_trained:
        param_dict = load_checkpoint(args.ckpt_pre_trained)
        load_param_into_net(train_net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()
    total_train_steps = iters_per_epoch * args.train_epochs
    lr_iter = _build_lr_schedule(args, total_train_steps)
    opt = nn.Momentum(params=train_net.trainable_params(), learning_rate=lr_iter, momentum=0.9, weight_decay=0.0001,
                      loss_scale=args.loss_scale)

    # Fixed loss scale with drop_overflow_update=False: the optimizer always
    # steps and is responsible for un-scaling, which is why the same
    # loss_scale value is passed to nn.Momentum above.
    manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
    model = Model(train_net, optimizer=opt, amp_level="O3", loss_scale_manager=manager_loss_scale)

    model.train(args.train_epochs, dataset, callbacks=_build_callbacks(args, iters_per_epoch))


if __name__ == '__main__':
    set_seed(1)
    args = parse_args()
    if args.device == 'GPU':
        # No communication init here. train() calls init() itself when
        # --is_distributed is set. Calling init("nccl") unconditionally forced
        # NCCL setup even for single-card runs and duplicated that init() call
        # in distributed runs.
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    else:
        # DEVICE_ID is normally exported by the Ascend launcher; fall back to
        # card 0 instead of crashing with int(None) when it is absent.
        context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
                            device_target="Ascend", device_id=int(os.getenv('DEVICE_ID', '0')))

    train(args)
